
#Copyright 2023 Massachusetts Institute of Technology

#Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

#The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.



#!/usr/bin/env python3
import numpy as np
from gcgridobj import cstools, latlontools, regrid
import os
from datetime import datetime, timedelta
import mcstats
import mortality_v3 as mortality
import netCDF4
import pandas
import time
import pickle

# tqdm is optional. When it can be imported, messages are routed through
# tqdm.write (presumably so they do not interleave with progress bars); otherwise
# the built-in print is used.
# NOTE(review): use_tqdm / print_cmd are not referenced elsewhere in this script.
try:
    from tqdm import tqdm
except ImportError:
    # Only a missing tqdm should trigger the fallback; other errors must surface
    use_tqdm = False
    print_cmd = print
else:
    use_tqdm = True
    print_cmd = tqdm.write

### USER INPUT
# Basic parameters of the simulation
start_date = datetime(2014,8,1,0,0,0)                  # Start date (inclusive)
end_date = datetime(2015,8,1,0,0,0)                    # End date (not inclusive)
# NOTE(review): start_date / end_date are not referenced below in this script - confirm they are still needed.

# Uncertainty quantification
n_rand = 1400                                          # Number of random draws (Monte Carlo)
rng_seed = 20110615                                    # Seed for the RNG (to ensure reproducibility)
rng_name = 'pseudorandom'                              # If installed, can also use a QRNG (Sobol set)
n_cpus = int(os.getenv('SLURM_CPUS_PER_TASK',1))       # Number of CPUs - hardcode if not in SLURM

# Non-simulation data locations
GPW_dir = '/n/mickley/users/ebonilla/GPWv4pt11'
mort_data_dir = 'Preprocessed_Mort_Data'               # Location of (eg) GPW_to_WHO.pkl
WHO_dir = os.path.join(mort_data_dir,'WHO_data')       # WHO 2016 mortality data
GEMM_file = os.path.join(mort_data_dir,'gemm_crf.csv') # Only needed if GEMM CRFs are used

# Pre-processed simulation data locations
base_dir = 'Data_SA'                                   # Directory containing simulation output (processed)
# Dictionary of all the simulations to compare against the reference scenario,
# with one entry per pre-processed netCDF file in base_dir
sim_data = {'base':     'GEOSChem.PM25.2014mean.nc4',
            'no_fires': 'GEOSChem.PM25.2014mean_noGFED.nc4',
            'fires_only':'GEOSChem.PM25.2014mean_firesonly.nc4',
            'pm25_0':    'GEOSChem.PM25_0.nc4' }
ref_map = 'base'                                       # Key in sim_data of the scenario acting as the reference

# Where to write output data
output_file = 'mortality_output_dump_2014.pkl'

# Grid information
common_chi_grid = True                                 # Same grid for all simulations?
hrz_grid_chi = None                                    # Horizontal grid used by simulation; set to a grid object to force it, or None to autoidentify
is_GCHP = False                                        # If autoidentifying - is this a cubed sphere simulation?

# Exposure response functions to consider - remove those not needed/wanted
ERF_set = [mortality.erf_lib['chen_hoek_2020_AC'  ],   # Chen and Hoek 2020 all-cause for PM2.5
           mortality.erf_lib['hoek_meta_2013_CV'  ],   # Hoek 2013 meta-analysis cardiovascular for PM2.5
           mortality.erf_lib['epa_2011_AC'        ],   # EPA expert solicitation all-cause for PM2.5
           mortality.erf_lib['krewski_2009_AC'    ],   # Krewski et al. 2009 all-cause for PM2.5
           mortality.erf_lib['vodonos_2018_AC'    ],   # Vodonos et al 2018 all-cause for PM2.5
           # mortality.erf_lib['turner_2015_RD'   ],   # Turner et al 2015 respiratory disease for ozone - EXB: removed because there is no O3 in our data
          ]
use_GEMM = False                                       # Do we want to add the GEMM for PM2.5?
# GEMM filtering options (only read when use_GEMM is True). GEMM ERFs whose
# disease name contains 'NCD+LRI' are the "combined" curves.
combined_only = False                                  # Keep only the combined 'NCD+LRI' GEMM ERFs?
exclude_combined = False                               # Drop the combined 'NCD+LRI' GEMM ERFs? (ignored if combined_only)

### START OF MAIN CODE

# Echo the CPU count so the user can confirm the allocation actually received
print('Running with {:d} CPUs'.format(n_cpus))

# Fewer than 1000 Monte Carlo draws: warn that results may not be converged
if 1000 > n_rand:
    print('WARNING: Low number of calculations ({:d}, < 1000). You may wish to consider a higher number and/or check convergence'.format(n_rand))

# Optionally append the GEMM exposure-response functions (built from GEMM_file by
# mortality.gen_GEMM_erfs) to ERF_set. ERFs whose disease name contains 'NCD+LRI'
# are the "combined" curves; the combined_only / exclude_combined flags (which must
# be defined before this point) choose which subset is kept. If both flags are
# True, combined_only takes precedence.
if use_GEMM:
    # Explicit check rather than assert so the validation survives `python -O`
    if not os.path.isfile(GEMM_file):
        raise FileNotFoundError('GEMM data not found at {}'.format(GEMM_file))
    GEMM_full = mortality.gen_GEMM_erfs(GEMM_file)

    for ERF in GEMM_full:
        if combined_only:
            # Keep only the "combined" ERFs
            keep = 'NCD+LRI' in ERF.disease_name
        elif exclude_combined:
            # Remove the "combined" ERFs
            keep = 'NCD+LRI' not in ERF.disease_name
        else:
            keep = True
        if keep:
            ERF_set.append(ERF)

# Exposure data are organized into a dictionary with one entry per scenario; each
# entry holds the exposure field(s) read from that scenario's file plus its
# horizontal grid.
#
# Grid handling: a grid already assigned to hrz_grid_chi (USER INPUT) is used for
# every file. Otherwise the grid is auto-identified: extracted from the file for
# cubed-sphere (GCHP) output, or generated from the fixed lat-lon specification
# below (it is NOT read from the file's lat/lon variables). When common_chi_grid is
# True the grid is identified once, from the first file, and reused; otherwise it
# is identified separately for each file.
user_grid = hrz_grid_chi
chi_data = {}
for src, f in sim_data.items():
    f_full = os.path.join(base_dir, f)
    temp_data = {}
    with netCDF4.Dataset(f_full, 'r') as nc:
        for field in ['PM25']:  # 'MDA8-O3' removed (EXB): no O3 variable in these files
            if field not in nc.variables:
                print(field)
                print(nc)
                raise ValueError('Field {} not found in file {}'.format(field, f_full))
            # Take index 0 of the first two dimensions (presumably time and level - confirm);
            # copy so the data outlive the closed dataset
            temp_data[field] = nc[field][0, 0, ...].copy()

        if user_grid is not None:
            file_grid = user_grid
        elif common_chi_grid and hrz_grid_chi is not None:
            # Shared grid already identified from an earlier file
            file_grid = hrz_grid_chi
        elif is_GCHP:
            file_grid = cstools.extract_grid(nc)
        else:
            # Hard-coded 0.5 x 0.625 degree lat-lon domain (lon -85..-30, lat -60..15)
            file_grid = latlontools.gen_grid(lat_stride=0.5, lon_stride=5/8, half_polar=True, center_180=True,
                                             lon_range=[-85, -30], lat_range=[-60, 15])
    hrz_grid_chi = file_grid
    # Store in a dictionary ('O3_MDA8_ANN' entry removed by EXB: no O3 data)
    chi_data[src] = {'PM25': temp_data['PM25'],
                     'grid': file_grid}

# Run the mortality calculation for every scenario relative to the reference scenario
output_data = mortality.calc_mortality(chi_data, ref_map, common_chi_grid, ERF_set, GPW_dir, WHO_dir, n_rand,
                                       rng_opts={'rng_name': rng_name, 'random_seed': rng_seed},
                                       mort_preprocessor_dir=mort_data_dir, n_cpus=n_cpus)
# Save the full output; the context manager guarantees the file is flushed and closed
with open(output_file, 'wb') as f_out:
    pickle.dump(output_data, f_out)
# Show results: one summary line per (non-reference scenario, ERF) pair
for scen_name in [s for s in chi_data if s != ref_map]:
    for erf in ERF_set:
        erf_results = output_data[erf.exposure_factor][erf.name]
        scen_results = erf_results[scen_name]
        d_exp = np.nansum(scen_results['delta_exp'])
        aff_pop = np.nansum(erf_results['affected_pop'])
        base_inc = np.nansum(erf_results['base_morts'])
        # Average over the last axis (presumably the Monte Carlo draws - confirm), then total
        morts = np.nansum(np.nanmean(scen_results['morts_vs_ref'], axis=-1))
        pop_exp = d_exp / aff_pop
        summary = 'ERF {:40s} due to {:15s}: {:9.0f} morts due to a change of {:12.0f} person-chi over {:12.0f} people (pop-exp: {:6.3f}, base inc: {:6.0f} x1,000)'.format(
            erf.name, scen_name, morts, d_exp, aff_pop, pop_exp, base_inc / 1000.0)
        print(summary)
